# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import math
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import deprecations
from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip
from jax._src.mesh import AxisType, AbstractMesh, Mesh
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (
    pmap_sharding_devices_indices_map, NamedSharding, GSPMDSharding)
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import prng

jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

with contextlib.suppress(ImportError):
  import pytest
  pytestmark = pytest.mark.multiaccelerator


def create_array(shape, sharding, global_data=None):
  if global_data is None:
    global_data = np.arange(math.prod(shape)).reshape(shape)

  return array.make_array_from_callback(
      shape, sharding, lambda idx: global_data[idx]), global_data


class JaxArrayTest(jtu.JaxTestCase):

  def test_array_impl_name(self):
    self.assertEqual(array.ArrayImpl.__name__, "ArrayImpl")

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_jax_array_value(self, mesh_axes):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, global_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes))
    for s in arr.addressable_shards:
      self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
      self.assertArraysEqual(s.data, global_data[s.index])
    self.assertArraysEqual(arr._value, global_data)
    if arr._npy_value is not None:
      self.assertArraysEqual(arr._npy_value, global_data)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but for convenient purposes, checking for only
       # 2. The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True)
    global_input_shape = (8, 2)
    s = jax.sharding.NamedSharding(global_mesh, mesh_axes)
    arr, global_input_data = create_array(global_input_shape, s)
    self.assertEqual(arr.ndim, 2)
    self.assertEqual(arr.size, 16)
    self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
    self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
    replica_ids = [i.replica_id for i in arr.addressable_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in arr.addressable_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
    for i, s in enumerate(arr.addressable_shards):
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
      self.assertArraysEqual(s.data, global_input_data[s.index])
      self.assertArraysEqual(s.data, arr.addressable_data(i))

    for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  def test_addressable_data(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    s = jax.sharding.NamedSharding(global_mesh, P(None))
    arr, inp_data = create_array(shape, s)
    for i in range(len(arr)):
      self.assertArraysEqual(inp_data, arr.addressable_data(i))

  def test_array_delete(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_single_device_array_usage_after_delete(self):
    x = jnp.array([1, 2, 3])
    x.delete()

    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      _ = x + 1

  @parameterized.named_parameters(
      ('no_global_shape', np.arange(10).reshape(2, 5), None),
      ('global_shape_prefix', {'a': np.arange(10).reshape(2, 5)}, (2, 5)),
      ('global_shape_full', {'a': np.arange(10).reshape(2, 5)}, {'a': (2, 5)}),
  )
  def test_array_from_local_data_single_host(self, data, global_shape):
    jnp_data = jax.make_array_from_process_local_data(
        jax.devices()[0], data, global_shape
    )
    jax.tree.map(self.assertArraysEqual, data, jnp_data)

  @parameterized.named_parameters(
      ('global_shape_prefix', {'a': np.arange(10).reshape(2, 5)}, (2, 8)),
      ('global_shape_full', {'a': np.arange(10).reshape(2, 5)}, {'a': (2, 6)}),
      (
          'global_shape_extra',
          {'a': np.arange(10).reshape(2, 5)},
          {'a': (2, 5), 'b': (3, 5)},
      ),
  )
  def test_array_from_local_data_single_host_invalid_global_shape(
      self, data, global_shape
  ):
    with self.assertRaises(ValueError):
      jax.make_array_from_process_local_data(
          jax.devices()[0], data, global_shape
      )

  def test_multi_device_array_usage_after_delete(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32),
                         jax.sharding.NamedSharding(global_mesh, P('x')))
    arr.delete()

    with self.assertRaisesRegex(
        RuntimeError, r'Array has been deleted with shape=int32\[16\].'):
      _ = arr + 1

  def test_device_put(self):
    numpy_array = np.array([1, 2, 3])
    arr = jax.device_put(numpy_array, jax.devices()[0])
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    self.assertArraysEqual(arr, numpy_array)
    self.assertEqual(arr._committed, True)
    for i in arr.addressable_shards:
      self.assertArraysEqual(i.data, numpy_array)
      self.assertEqual(i.device, jax.devices()[0])
      self.assertEqual(i.index, (slice(None),))
      self.assertEqual(i.replica_id, 0)

  def test_device_put_array_delete(self):
    arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_device_put_to_cpu(self):
    mesh = Mesh(jax.devices(), 'x')
    mesh_cpu = Mesh(jax.devices('cpu'), 'x')
    x = np.zeros(16)
    y = jax.device_put(x, NamedSharding(mesh, P('x')))
    z = jax.device_put(y, NamedSharding(mesh_cpu, P('x')))
    for z_s in z.addressable_shards:
      self.assertArraysEqual(z_s.data, x[z_s.index])

  def test_array_device_get(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertArraysEqual(jax.device_get(arr), input_data)

  def test_repr(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertStartsWith(repr(arr), "Array(")

  def test_empty_repr(self):
    shape = (0, 5)
    dtype = 'float32'
    x = jnp.empty(shape, dtype)
    self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})")

  def test_jnp_array(self):
    arr = jnp.array([1, 2, 3])
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
    self.assertEqual(arr._committed, False)
    self.assertFalse(arr.weak_type)

  def test_jnp_array_jit_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = jax.jit(lambda x, y: x + y)(a, b)
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_jnp_array_jnp_add(self):
    arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_jnp_array_normal_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = a + b
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_array_sharded_astype(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
    self.assertLen(arr_float32.addressable_shards, 8)
    for i in arr_float32.addressable_shards:
      self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))

  def test_jnp_array_astype(self):
    arr = jnp.array([1, 2, 3])
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, arr.astype(np.float32))

  def test_array_delete_idempotent(self):
    mesh = jtu.create_mesh((2,), ('x',))
    arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x')))

    arr.delete()
    self.assertTrue(arr.is_deleted())

    arr.delete()  # Run delete again to check if it's idempotent.
    self.assertTrue(arr.is_deleted())

  def test_sharded_add(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    b, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x')))
    out = a + b
    expected = input_data + input_data
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_sharded_zeros_like(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    out = jnp.zeros_like(a)
    expected = jnp.zeros(input_data.shape, dtype=a.dtype)
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_zeros_like(self):
    a = jnp.array([1, 2, 3], dtype=np.int32)
    out = jnp.zeros_like(a)
    expected = np.zeros(a.shape, dtype=np.int32)
    self.assertArraysEqual(out, expected)
    self.assertTrue(dispatch.is_single_device_sharding(out.sharding))

  def test_wrong_num_arrays(self):
    if jax.device_count() < 4:
      self.skipTest('Requires more than 4 devices')
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 2), ('x', 'y'))
    devices = jax.local_devices()[:2]  # Taking up to 2 devices
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    di_map = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices]
    with self.assertRaisesRegex(
        ValueError,
        r'Expected 2 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 1'):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:1], committed=True)

    for buf, d in zip(list(bufs), jax.local_devices()[2:4]):
      bufs.append(jax.device_put(buf, d))
    with self.assertRaisesRegex(
        ValueError,
        r'Expected 2 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 4'):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_arrays_not_in_device_assignment(self):
    if jax.device_count() < 4:
      self.skipTest('Requires more than 4 devices')
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 2), ('x', 'y'))
    # sharding device ids = {0, 1}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {2, 3}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        "Sharding contains devices {0, 1} that are not present in per-device "
        "arrays. Per-device arrays contain devices {2, 3} that are not present "
        "in the sharding."):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_different_devices_in_arrays_than_sharding(self):
    if jax.device_count() < 3:
      self.skipTest('Requires more than 3 devices')
    shape = (8, 2)
    mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
    # sharding device ids = {1, 2}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 1}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{2\} that are not present in per-device "
        r"arrays. Per-device arrays contain devices \{0\} that are not present "
        "in the sharding."):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_duplicated_devices_in_arrays(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 2), ('x', 'y'))
    # Sharding device ids = {0, 1}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 0}
    bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
    with self.assertRaisesRegex(
        ValueError,
        'When making an array from single-device arrays, the input arrays must'
        ' be from distinct devices'):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
      ("mesh_replicated", P(()), (8, 4)),
  )
  def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
    shape = (8, 4)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    inp_data = np.arange(5)

    str_expected_shard_shape = str(expected_shard_shape).replace(
        r"(", r"\(").replace(r")", r"\)")
    with self.assertRaisesRegex(
        ValueError,
        f"Expected shard shape {str_expected_shard_shape} doesn't match the "
        "single device array shape"):
      array.make_array_from_callback(shape, mps, lambda idx: inp_data)

  def test_mismatch_dtype(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
    indices = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
    with self.assertRaisesRegex(
        ValueError,
        "Input buffers to `Array` must have matching dtypes. "
        "Got int32, expected float32"):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_array_iter_pmap_sharding(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')
    if config.pmap_shmap_merge.value:
      self.skipTest(
          'Under `pmap_shmap_merge=True`, `y[0]` of sharded `y` will replicate'
          ' because of the indexing operation. '
      )

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin)(x)
    self.assertArraysEqual([list(a.devices())[0] for a in y],
                           y.sharding._device_assignment,
                           allow_object_dtype=True)

    sin_x = iter(np.sin(x))
    for i, j in zip(iter(y), sin_x):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysAllClose(i, j)

  def test_array_iter_pmap_sharding_last_dim_sharded(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin, out_axes=1)(x)

    for i, j in zip(iter(y), iter(np.sin(x).T)):
      self.assertArraysAllClose(i, j)

  def test_array_iter_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)

  def test_array_iter_replicated_multi_device(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)
      self.assertLen(i.sharding.device_set, 8)
      self.assertTrue(
        op_shardings.are_hlo_shardings_equal(
            arr.sharding._to_xla_hlo_sharding(arr.ndim),
            i.sharding._to_xla_hlo_sharding(i.ndim)))

  def test_array_getitem_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, input_data[2:4, 0:1])

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])

  def test_array_getitem_compile_multi_device_sharding(self):
    def _check(out, inp, shard_shape):
      self.assertArraysEqual(out, inp)
      self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape)
      self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding)

    global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
    input_shape = (4, 4, 2)
    arr, np_inp = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z')))

    _check(arr[:, -1, :], np_inp[:, -1, :], (2, 1))
    _check(arr[0, 0, 0], np_inp[0, 0, 0], ())
    _check(arr[-1, -1, :], np_inp[-1, -1, :], (1,))
    _check(arr[:, 1, 0], np_inp[:, 1, 0], (2,))
    _check(arr[:, :, :], np_inp[:, :, :], (2, 2, 1))
    _check(arr[3, :, :], np_inp[3, :, :], (2, 1))
    _check(arr[-1, -1, -1], np_inp[-1, -1, -1], ())
    _check(arr[2, -1, :], np_inp[2, -1, :], (1,))
    _check(arr[2, 3, 1], np_inp[2, 3, 1], ())
    _check(arr[-1], np_inp[-1], (2, 1))
    _check(arr[:], np_inp[:], (2, 2, 1))
    _check(arr[np.array(0), :, :], np_inp[np.array(0), :, :], (2, 1))
    _check(arr[jnp.array(0), :, :], np_inp[jnp.array(0), :, :], (2, 1))
    _check(arr[0, :2, 1], np_inp[0, :2, 1], (2,))
    _check(arr[:, 1::2], np_inp[:, 1::2], (2, 2, 1))
    _check(arr[:, -1:, :], np_inp[:, -1:, :], (2, 1, 1))
    _check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
    _check(arr[:4], np_inp[:4], (2, 2, 1))
    _check(arr[::-1], np_inp[::-1], (2, 2, 1))
    _check(arr[1], np_inp[1], (2, 1))

  def test_array_getitem_replicated_multi_device(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))

    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, np.array([[4], [6]]))
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        op_shardings.are_hlo_shardings_equal(
            arr.sharding._to_xla_hlo_sharding(arr.ndim),
            s.sharding._to_xla_hlo_sharding(s.ndim)))

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        op_shardings.are_hlo_shardings_equal(
            arr.sharding._to_xla_hlo_sharding(arr.ndim),
            s.sharding._to_xla_hlo_sharding(s.ndim)))

  def test_array_iter_mesh_pspec_sharding_single_device(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    single_dev = jax.devices()[1:2]
    mesh = jax.sharding.Mesh(np.array(single_dev), ('x'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(mesh, P('x')))

    for i, j in zip(arr, iter(input_data)):
      self.assertArraysEqual(i, j)
      self.assertEqual(i.devices(), {single_dev[0]})

  def test_array_shards_committed(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([1, 2, 3])
    for s in x.addressable_shards:
      self.assertEqual(s.data._committed, x._committed)
      self.assertFalse(s.data._committed)

    y = jax.device_put(x, jax.devices()[1])
    for s in y.addressable_shards:
      self.assertEqual(s.data._committed, y._committed)
      self.assertTrue(s.data._committed)

  def test_array_jnp_array_copy_multi_device(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    c_arr = jnp.array(arr, copy=True)
    self.assertArraysEqual(arr, c_arr)
    self.assertEqual(arr._committed, c_arr._committed)
    for a, c in safe_zip(arr.addressable_shards, c_arr.addressable_shards):
      self.assertArraysEqual(a.data, c.data)
      self.assertEqual(a.index, c.index)
      self.assertEqual(a.replica_id, c.replica_id)
      self.assertEqual(a.device, c.device)
      self.assertNotEqual(a.data.unsafe_buffer_pointer(),
                          c.data.unsafe_buffer_pointer())

  def test_array_addressable_shards(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    for a in arr.addressable_shards:
      self.assertIsInstance(a.data, array.ArrayImpl)

    x = jnp.array([1, 2, 3])
    self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)

  def test_array_not_hashable(self):
    x = jnp.arange(4)
    with self.assertRaisesRegex(TypeError, "unhashable type"):
      hash(x)

    with self.assertRaisesRegex(TypeError, "unhashable type"):
      jax.jit(hash)(x)

    with self.assertRaisesRegex(TypeError, "unhashable type"):
      jax.vmap(hash)(x)

  def test_shape_dtype_struct_sharding_jit(self):
    mesh = jtu.create_mesh((8,), ('x'))
    s = jax.sharding.NamedSharding(mesh, P('x'))

    x_dummy = jax.ShapeDtypeStruct(
        shape=(16,),
        dtype=jnp.dtype('float32'),
        sharding=s)

    def f(x):
      return x * 2

    c = jax.jit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertLen(input_shardings, 2)
    self.assertEqual(input_shardings[1], {})
    self.assertEqual(input_shardings[1], {})

    self.assertTrue(
        op_shardings.are_hlo_shardings_equal(
            input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))
    self.assertTrue(
        op_shardings.are_hlo_shardings_equal(
            output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))

  # TODO(b/399879011): GPU is the only platform that has an implementation for
  # this, which exists in py_client.cc. Ideally, this would be replaced with
  # some kind of auto-defrag-on-OOM.
  @jtu.run_on_devices('gpu')
  def test_defragment(self):
    # Since the GPU implementation is in py_client.cc, it cannot be exposed via
    # the PjRt C API.
    if xb.using_pjrt_c_api():
      self.skipTest('Manual defragment not exposed via PJRT C API')

    # Create a few arrays
    global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',))
    shape = (8, 2)
    mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',))
    arr1, data = create_array(shape, mpsharding)
    arr2, _ = create_array(shape, mpsharding, data)
    arr3, _ = create_array(shape, mpsharding, data)

    # Delete one of them
    arr2.delete()

    # Defragment.
    xb.get_backend().defragment()

    # Sanity check remaining arrays
    self.assertArraysEqual(arr1, data)
    self.assertArraysEqual(arr1 + arr3, data * 2)

    # TODO(skyewm): check that defragmentation actually happened. I originally
    # thought to do this with unsafe_buffer_pointer(), but that's not always the
    # device memory address. Other ideas include causing enough fragmentation to
    # OOM, and exposing allocator stats in Python.

  def test_on_device_size_in_bytes(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    a, _ = create_array(
        (8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    shard_size = a.addressable_shards[0].data.on_device_size_in_bytes()
    self.assertGreaterEqual(shard_size, 4 * 2)
    self.assertEqual(shard_size * len(a.global_shards),
                     a.on_device_size_in_bytes())

  def test_array_is_ready(self):
    x = jax.device_put(jnp.arange(8.), jax.devices()[0])
    x.is_ready()  # doesn't crash

  def test_process_allgather_single_host(self):
    x = jnp.arange(8.)
    out = multihost_utils.process_allgather(x, tiled=True)
    self.assertEqual(out.shape, x.shape)
    self.assertArraysEqual(out, x)

    out = multihost_utils.process_allgather(x)
    self.assertEqual(out.shape, (1, x.shape[0]))
    self.assertArraysEqual(out, np.expand_dims(x, axis=0))

  def test_broadcast_one_to_all_single_host(self):
    x = jnp.arange(8, dtype=jnp.uint8)
    out = multihost_utils.broadcast_one_to_all(x)
    self.assertEqual(out.shape, x.shape)
    self.assertEqual(out.dtype, x.dtype)
    self.assertArraysEqual(out, x)

  @jtu.sample_product(
    dtype=jtu.dtypes.all,
    shape=[(), (10), (2, 3)],
  )
  @jtu.run_on_devices("cpu")
  def test_buffer_protocol(self, dtype, shape):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jax.device_put(x)
    if dtype == jax.dtypes.bfloat16:
      with self.assertRaisesRegex(
          BufferError,
          'Buffers of type BF16 are not supported by the Python buffer '
          'protocol.'
      ):
        memoryview(y)
      return

    x_bytes = memoryview(x).tobytes()
    y_bytes = memoryview(y).tobytes()
    self.assertEqual(x_bytes, y_bytes)

  @jtu.run_on_devices("cpu")
  def test_buffer_protocol_deletion(self):
    rng = jtu.rand_default(self.rng())
    x = rng((3, 4), np.float32)
    y = jax.device_put(x)
    x_bytes = memoryview(x).tobytes()
    y_view = memoryview(y)
    # The array does not actually get deleted until any external reference is
    # dropped. Arguably we should make calling delete() in these circumstances
    # return an error instead, but that would be a behavior change for existing
    # users.
    y.delete()
    y_bytes = y_view.tobytes()
    self.assertEqual(x_bytes, y_bytes)

  def test_array_copy_to_host_async(self):
    global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jax.jit(lambda: jnp.arange(8.),
                out_shardings=jax.NamedSharding(global_mesh, P(None)))()
    self.assertLen(x.sharding.device_set, 4)
    x.copy_to_host_async()  # doesn't crash
    self.assertArraysEqual(np.arange(8.), x)

  def test_array_fully_replicated_shard(self):

    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    arr, inp_data = create_array(
        inp_shape, jax.sharding.NamedSharding(global_mesh, P()))
    fs = arr._fully_replicated_shard()
    self.assertEqual(fs.shape, inp_shape)
    self.assertTrue(dispatch.is_single_device_sharding(fs.sharding))
    self.assertArraysEqual(fs, inp_data)
    self.assertArraysEqual(arr.addressable_data(0), inp_data)

  def test_shard_array_to_fully_replicated(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    sharding = jax.sharding.NamedSharding(global_mesh, P())
    arr = jnp.arange(16)
    self.assertFalse(arr._committed)
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    out = jax.jit(lambda x: x * 2, in_shardings=sharding)(arr)
    self.assertTrue(out.sharding.is_fully_replicated)
    self.assertArraysEqual(out, arr * 2)

  def test_fully_replicated_donated_array_is_deleted(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    sharding = jax.sharding.NamedSharding(global_mesh, P())
    arr = jnp.arange(16)
    arr_copy = arr.copy()
    self.assertFalse(arr._committed)
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    out = jax.jit(lambda x: x * 2, in_shardings=sharding, donate_argnums=0)(arr)
    self.assertTrue(out.sharding.is_fully_replicated)
    self.assertArraysEqual(out, arr_copy * 2)
    self.assertTrue(arr.is_deleted())

  @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
  def test_shards_have_correct_dtype(self, dtype):
    x = jnp.ones((), dtype=dtype)
    for shard in x.addressable_shards:
      self.assertEqual(shard.data.dtype, dtype)

  def test_make_array_from_callback_global_array(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    sharding = jax.sharding.NamedSharding(mesh, P())
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, sharding)

    out = jax.make_array_from_callback(np_inp.shape, sharding,
                                       lambda idx: arr[idx])
    self.assertArraysEqual(out, arr)
    self.assertEqual(out.sharding, sharding)

    sharding2 = NamedSharding(mesh, P('x', 'y'))
    arr2 = jax.device_put(np_inp, sharding2)
    out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
                                       lambda idx: arr2[idx])
    self.assertArraysEqual(out2, arr2)
    self.assertEqual(out2.sharding, sharding2)

  def test_make_array_from_process_data_single_host_data_sharding(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    data = np.ones((256, 512))
    s = jax.NamedSharding(mesh, P('x'))
    result = jax.make_array_from_process_local_data(s, data)
    self.assertArraysEqual(result, data)
    self.assertEqual(result.sharding, s)

    with jax.set_mesh(mesh):
      result = jax.make_array_from_process_local_data(P('x'), data)
      self.assertArraysEqual(result, data)
      self.assertEqual(result.sharding, s)

  @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
  @jtu.run_on_devices("gpu")
  def test_pinned_host_npy_value_doesnt_cache(self, dtype):
    # see https://github.com/jax-ml/jax/issues/26216
    d_tensor = jnp.array(0, dtype=dtype)
    d_sharding = d_tensor.sharding
    h_sharding = d_sharding.with_memory_kind("pinned_host")
    h_tensor = jax.device_put(d_tensor, h_sharding)
    np.array(h_tensor)
    self.assertIsNone(h_tensor._npy_value)

  def test_make_array_from_single_device_arrays_no_dtype_error(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        'If the Array has no addressable shards, `dtype` must be provided via '
        'the `dtype` argument to `jax.make_array_from_single_device_arrays`.'):
      jax.make_array_from_single_device_arrays((8, 2), s, [])

  def test_make_array_from_single_device_arrays_bad_dtype_error(self):
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    shape = (8, 2)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)
    with self.assertRaisesRegex(
        ValueError,
        'If `dtype` is provided to `jax.make_array_from_single_device_arrays`, '
        'it must match the dtype of the addressable shards.'):
      jax.make_array_from_single_device_arrays(
          shape, s, [arr], dtype=jnp.float32)

  @jtu.with_explicit_mesh((2,), ('x',))
  def test_unreduced_printing(self, mesh):
    x = jax.device_put(jnp.arange(8., dtype='float32'), P('x'))
    x = jax.lax.reduce_sum(x, [0], out_sharding=P(unreduced={'x'}))
    self.assertIn('nreduced', str(x.sharding))
    self.assertIn('Array(shape=(), dtype=float32, sharding=', str(x))


class ShardingTest(jtu.JaxTestCase):

  def test_mesh_pspec_sharding_interface(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    pspec = P('y', 'x')
    global_shape = (8, 4)
    mp_sharding = jax.sharding.NamedSharding(mesh, pspec)
    di_map = mp_sharding.devices_indices_map(global_shape)
    hlo_sharding = mp_sharding._to_xla_hlo_sharding(len(global_shape))
    device_assignment = mp_sharding._device_assignment

    self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
    self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
                           allow_object_dtype=True)
    self.assertTrue(hlo_sharding.is_tiled())
    self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
    self.assertListEqual(hlo_sharding.tile_assignment_devices(),
                         [0, 2, 4, 6, 1, 3, 5, 7])

  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
  def test_util_clear_cache(self):
    mesh = jtu.create_mesh((1,), ('x',))
    s = NamedSharding(mesh, P())
    s.devices_indices_map((8,))
    jax.clear_caches()
    s.devices_indices_map((8,))
    c = common_devices_indices_map.cache_info()
    self.assertEqual(c.currsize, 1)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_none_x", P(None, "x")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_op_sharding_indices(self, pspec):
    shape = (8, 4)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    ops = GSPMDSharding(
        list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape)))
    self.assertDictEqual(
        ops.devices_indices_map(shape), mps.devices_indices_map(shape))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
      ("mesh_fully_replicated", P(), (8, 4)),
  )
  def test_shard_shape(self, pspec, expected_shard_shape):
    shape = (8, 4)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    self.assertEqual(mps.shard_shape(shape), expected_shard_shape)

  def test_uneven_shard_error(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
        r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
        r"factors: \[4, 2\] should evenly divide the shape\)"):
      mps.shard_shape((8, 3))

  @jtu.ignore_warning(category=DeprecationWarning)
  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
  def test_pmap_sharding_hash_eq(self):
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')
    if config.pmap_shmap_merge.value:
      self.skipTest(
          'There is not an equivalent cache to test when pmap_shmap_merge=True.'
      )

    shape = (2, 2)
    num_elements = math.prod(shape)
    inp_data = np.arange(num_elements).reshape(shape)
    out = jax.pmap(lambda x: x)(inp_data)
    self.assertIsInstance(out.sharding, jax.sharding.PmapSharding)
    # Populate the device_indices_map cache.
    _ = out.sharding.devices_indices_map(shape)
    cache_info1 = pmap_sharding_devices_indices_map.cache_info()

    inp_data2 = np.arange(num_elements, num_elements + num_elements).reshape(shape)
    out2 = jax.pmap(lambda x: x)(inp_data2)
    # Populate the device_indices_map cache.
    _ = out2.sharding.devices_indices_map(shape)
    cache_info2 = pmap_sharding_devices_indices_map.cache_info()

    self.assertGreater(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

  def test_is_compatible_error(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
    mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))

    with self.assertRaisesRegex(
        ValueError,
        r"Sharding NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\)"
        ' is only valid for values of rank at least 4, but was applied to a'
        ' value of rank 2'):
      mps.check_compatible_aval(shape)

  def test_is_subclass(self):
    # array version of api_test.py::APITest::test_is_subclass
    self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
    self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))

  def test_gspmd_sharding_repr(self):
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [4, 1, 2]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    op.replicate_on_last_tile_dim = True
    s = GSPMDSharding(jax.devices(), op)
    # memory kind also appears in the repr but only for TPU.
    self.assertIn(
        'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
        'last_tile_dim_replicate}', repr(s))

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.REPLICATED
    s2 = GSPMDSharding(jax.devices(), op2)
    # memory kind also appears in the repr but only for TPU.
    self.assertIn('GSPMDSharding({replicated}', repr(s2))

  @parameterized.named_parameters(
      ("2d_mesh_x",                (1, 1), P("x", "y")),
      ("2d_mesh_x_y",              (4, 2), P("x", "y")),
      ("2d_mesh_empty",            (2, 1), P()),
      ("2d_mesh_p_none",           (2, 1), P(None)),
      ("2d_mesh_none_none",        (2, 1), P(None, None)),
      ("2d_mesh_tuple_empty",      (2, 1), P((),)),
      ("2d_mesh_x_none",           (2, 1), P(('x',), None)),
      ("2d_mesh_xy_none",          (2, 1), P(('x', 'y'), None)),
      ("2d_mesh_x_tuple_empty",    (2, 1), P('x', (), (), ())),
      ("2d_mesh_3_tuple_empty",    (2, 1), P((), (), ())),
      ("3d_mesh2_x_none_none",     (1, 2, 4), P('x', None, None)),
      ("3d_mesh2_x_y_none",        (1, 1, 4), P('x', 'y', None)),
      ("3d_mesh2_xy_none",         (1, 1, 4), P(('x', 'y'), None)),
  )
  def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
    if len(mesh_shape) == 2:
      axis_names = ('x', 'y')
    elif len(mesh_shape) == 3:
      axis_names = ('x', 'y', 'z')
    else:
      axis_names = ('x',)
    mesh = jtu.create_mesh(mesh_shape, axis_names)
    mps = jax.sharding.NamedSharding(mesh, pspec)
    shape = (8, 2, 4)
    mps_op_sharding = mps._to_xla_hlo_sharding(len(shape))
    ops_ifr = op_shardings.is_hlo_sharding_replicated(mps_op_sharding)
    self.assertEqual(mps.is_fully_replicated, ops_ifr)

  def test_pmap_sharding_repr(self):
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')
    out = jax.pmap(lambda x: x)(jnp.arange(2.))
    str(out.sharding)  # doesn't crash
    repr(out.sharding)  # doesn't crash

  def test_pspec_tuple(self):
    pspec = P('x', 'y', 'z')
    self.assertEqual(pspec, ('x', 'y', 'z'))
    self.assertEqual(pspec.index('z'), 2)
    self.assertEqual(hash(P(None, 'x', 'y', 'z')), hash(P((), 'x', 'y', 'z')))

  @parameterized.named_parameters(
      ('sharded_dim_0', (4, 2), 0),
      ('sharded_dim_1_0', (4, 2), 1),
      ('sharded_dim_2', (4, 2, 4), 2),
      ('sharded_dim_1_1', (2, 4), 1)
  )
  @jtu.ignore_warning(category=DeprecationWarning)
  def test_default_pmap_sharding(self, shape, sharded_dim):
    if jax.device_count() < 4:
      self.skipTest('Test needs >= 4 devices.')

    inp = jnp.arange(math.prod(shape)).reshape(shape)
    if config.pmap_shmap_merge.value:
      out = jax.pmap(lambda x: x, in_axes=sharded_dim, axis_name='x')(inp)
      actual_sharding = out.sharding
      expected_sharding = jax.sharding.NamedSharding(
          jax.sharding.Mesh(jax.devices()[: shape[sharded_dim]], 'x'),
          jax.P('x'),
      )
      self.assertEqual(actual_sharding.spec, expected_sharding.spec)
      self.assertEqual(actual_sharding._device_assignment, expected_sharding._device_assignment)
    else:
      compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
      # TOOD(dsuo): Investigate why
      # `compiled._executable.unsafe_call.in_handler.in_shardings` is of type
      # `GSPMDSharding` when `pmap_shmap_merge=True`. It should be
      # `NamedSharding`.
      actual_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings
      expected_sharding = jax.sharding.PmapSharding.default(shape, sharded_dim)
      self.assertEqual(actual_sharding.sharding_spec, expected_sharding.sharding_spec)
      self.assertEqual(actual_sharding._device_assignment, expected_sharding._device_assignment)

  @jtu.ignore_warning(category=DeprecationWarning)
  def test_default_pmap_sharding_with_devices(self):
    if jax.device_count() < 4:
      self.skipTest('Test needs >= 4 devices.')

    devs = jax.devices()
    new_order = (devs[0], devs[3], devs[2], devs[1])
    ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order)
    self.assertEqual(ps._device_assignment, new_order)

  @jtu.ignore_warning(category=DeprecationWarning)
  def test_default_pmap_sharding_replicated(self):
    x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32)
    x = jax.pmap(lambda x: x, in_axes=0, out_axes=None, axis_name='x')(x)
    if config.pmap_shmap_merge.value:
      expected_sharding = jax.sharding.NamedSharding(
          mesh=jax.sharding.Mesh(jax.local_devices(), 'x'),
          spec=jax.P(),
      )
      self.assertEqual(x.sharding, expected_sharding)
    else:
      ps = jax.sharding.PmapSharding.default(
          shape=(8,), sharded_dim=None,
          devices=jax.local_devices())
      self.assertEqual(x.sharding, ps)

  def test_mesh_repr(self):
    mesh = jtu.create_mesh((1, 1), ('x', 'y'))
    mesh_repr = repr(mesh)
    self.assertIn('axis_sizes', mesh_repr)
    self.assertIn('axis_names', mesh_repr)

  def test_are_shardings_equivalent(self):
    mesh = jtu.create_mesh((1,), ('x'))
    mesh2 = jtu.create_mesh((2, 1), ('x', 'y'))

    s1 = jax.sharding.NamedSharding(mesh, P('x'))
    s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    self.assertTrue(s1.is_equivalent_to(s2, 2))

    s3 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
    s4 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
    self.assertTrue(s3.is_equivalent_to(s4, 2))

    self.assertFalse(s1.is_equivalent_to(s3, 2))
    self.assertFalse(s2.is_equivalent_to(s3, 2))

    s5 = jax.sharding.NamedSharding(mesh2, P('x', 'y'))

    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.REPLICATED
    s6 = GSPMDSharding([jax.devices()[0]], op1)

    s7 = GSPMDSharding(jax.devices(), op1)

    # The OpSharding is replicated but the Sharding itself are on different
    # devices.
    self.assertFalse(s6.is_equivalent_to(s7, 2))

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_devices = [0, 1]
    op2.tile_assignment_dimensions = [2, 1]
    s8 = GSPMDSharding(list(mesh2.devices.flat), op2)

    self.assertTrue(s1.is_equivalent_to(s6, 2))
    self.assertTrue(s5.is_equivalent_to(s8, 2))
    self.assertFalse(s5.is_equivalent_to(s2, 2))

    s9 = jax.sharding.NamedSharding(mesh2, P('y'))

    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.OTHER
    op3.tile_assignment_devices = [0, 1]
    op3.tile_assignment_dimensions = [1, 1, 2]
    op3.replicate_on_last_tile_dim = True
    s10 = GSPMDSharding(list(mesh2.devices.flat), op3)

    self.assertTrue(s9.is_equivalent_to(s10, 2))

  def test_devices_indices_map_good_error_message(self):
    shape = (1, 2)
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        "Sharding.*implies that array axis 0 is partitioned 2 times, but the "
        "dimension size is 1"):
      s.devices_indices_map(shape)

  def test_scalar_input_wrong_pspec(self):
    mesh = jtu.create_mesh((1, ), ('x'))
    shape = ()
    s = jax.sharding.NamedSharding(mesh, P('x'))
    with self.assertRaisesRegex(
        ValueError,
        r"For scalars the PartitionSpec should be P()"):
      s.check_compatible_aval(shape)

  def test_mesh_caching_during_construction(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Requires >=2 devices")
    mesh1 = jax.sharding.Mesh(jax.devices(), 'x')
    mesh2 = jax.sharding.Mesh(jax.devices(), 'x')

    self.assertIs(mesh1, mesh2)

  def test_mesh_str(self):
    mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
    self.assertEqual(
        str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto))"
    )

  def test_make_array_from_callback_error(self):
    mesh_shape = (2, 3)
    global_shape = tuple(np.square(mesh_shape))
    mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
    pspec = P('x', 'y')
    sharding = jax.sharding.NamedSharding(mesh, pspec)
    n = math.prod(global_shape)
    global_x = jnp.arange(n).astype('uint32').reshape(global_shape)

    def f(arr):
      return array.make_array_from_callback(arr.shape, sharding, lambda i: arr[i])

    out = f(global_x)
    self.assertEqual(out.shape, global_shape)

    msg = "jax.make_array_from_callback cannot be called within a traced context"
    with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, msg):
      jax.jit(f)(global_x)

  def test_make_array_from_single_device_arrays_error(self):
    x = jnp.arange(10)
    sharding = x.sharding

    def f(x):
      return jax.make_array_from_single_device_arrays(x.shape, sharding, [x])

    msg = "jax.make_array_from_single_device_arrays requires a list of concrete arrays"
    with self.assertRaisesRegex(ValueError, msg):
      jax.jit(f)(x)

  def test_make_array_from_single_device_arrays_nonlist_error(self):
    x = jnp.arange(10)
    sharding = x.sharding

    def f(x):
      return jax.make_array_from_single_device_arrays(x.shape, sharding, x)

    msg = "jax.make_array_from_single_device_arrays `arrays` argument"
    with self.assertRaisesRegex(TypeError, msg):
      jax.jit(f)(x)

  def test_make_array_from_single_device_arrays_tuple(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (8, 8)
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape)).reshape(shape)

    arrays = tuple(
        jax.device_put(inp_data[index], d)
        for d, index in s.addressable_devices_indices_map(shape).items())

    jax.make_array_from_single_device_arrays(shape, s, arrays)  # doesn't crash

  def test_make_array_from_single_device_arrays_bad_inputs(self):
    x = jnp.arange(10)
    mesh = jtu.create_mesh((2,), ('x',))
    s = jax.sharding.NamedSharding(mesh, P('x'))
    x = jax.device_put(x, s)

    msg = ("When making an array from single-device arrays the input arrays "
           "must have one shard each. An argument array had 2 shard\\(s\\).")
    with self.assertRaisesRegex(ValueError, msg):
      jax.make_array_from_single_device_arrays(x.shape, s, [x, x])

  def test_gspmd_sharding_hash_eq(self):
    mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
    ns = NamedSharding(mesh, P('x', 'y', 'z'))

    x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3))
    x2 = GSPMDSharding.get_replicated(mesh._flat_devices_tuple)

    self.assertEqual(x1, x2)
    self.assertEqual(hash(x1), hash(x2))

  def test_device_attr(self):
    # For single-device arrays, x.device returns the device
    x = jnp.ones((2, 10))
    self.assertEqual(x.device, list(x.devices())[0])

    # For sharded arrays, x.device returns the sharding
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))
    x = jax.device_put(x, sharding)
    self.assertEqual(x.device, sharding)

  def test_to_device(self):
    device = jax.devices()[-1]
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))

    x = jnp.ones((2, 10))

    x_device = x.to_device(device)
    x_sharding = x.to_device(sharding)

    self.assertEqual(x_device.device, device)
    self.assertEqual(x_sharding.device, sharding)

  def test_mesh_with_axis_name_none(self):
    with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
      jax.sharding.Mesh(jax.devices(), (None, 'x'))

  def test_mesh_axis_types_mismatch(self):
    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jtu.create_mesh((2, 1), ('x', 'y'),
                      axis_types=jax.sharding.AxisType.Auto)

    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
                                axis_types=jax.sharding.AxisType.Auto)

    with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
      AbstractMesh((2,), ('x',), axis_types=("explicit",))

    with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
      AbstractMesh((2,), ('x',), axis_types="explicit")

    with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
      AbstractMesh((2, 2), ('x', 'y'),
                   axis_types=("explicit", AxisType.Explicit))

  def test_make_mesh_axis_types(self):
    Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual

    mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
    mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
    self.assertEqual(mesh1, mesh2)

    if deprecations.is_accelerated('jax-make-mesh-default-explicit'):
      mesh = jax.make_mesh((1, 1), ('x', 'y'))
      self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2)
    else:
      mesh = jax.make_mesh((1, 1), ('x', 'y'),
                          axis_types=(AxisType.Explicit,) * 2)
      self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2)

    mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
                         axis_types=(Explicit, Auto, Manual))

    self.assertEqual(mesh.explicit_axes, ('x',))
    self.assertEqual(mesh.auto_axes, ('y',))
    self.assertEqual(mesh.manual_axes, ('z',))

    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jax.make_mesh((1, 1), ('data', 'model'), axis_types=Explicit)

    mesh1 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
                          axis_types=(Explicit, Auto, Auto, Explicit, Explicit))
    mesh2 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
                          axis_types=(Explicit, Auto, Auto, Explicit, Auto))
    self.assertNotEqual(mesh1, mesh2)
    self.assertNotEqual(hash(mesh1), hash(mesh2))

  def test_memory_kind_with_abstract_mesh(self):
    abstract_mesh = AbstractMesh((2,), ('x',))
    ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host')
    self.assertEqual(ns.memory_kind, 'pinned_host')

    ns = NamedSharding(abstract_mesh, P())
    self.assertIsNone(ns.memory_kind)

    with self.assertRaisesRegex(
        ValueError, 'Got invalid memory kind'):
      NamedSharding(abstract_mesh, P(), memory_kind='weird_device')

  def test_pspec_mix_axis_types(self):
    mesh = AbstractMesh(
        (2, 2, 2, 2), ('a', 'b', 'c', 'd'),
        axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto,
                    AxisType.Manual))
    aval = jax.core.ShapedArray((16, 8, 4, 2), np.float32)

    out = aval.update(sharding=NamedSharding(mesh, P(('a', 'b', 'c'), 'd')))
    self.assertEqual(out.sharding.spec, P(('a', 'b'), None, None, None))

    out = aval.update(sharding=NamedSharding(mesh, P(('a', 'c'), 'b', 'd')))
    self.assertEqual(out.sharding.spec, P('a', 'b', None, None))

    out = aval.update(sharding=NamedSharding(mesh, P(('a', 'b'), 'c', 'd')))
    self.assertEqual(out.sharding.spec, P(('a', 'b'), None, None, None))

    with self.assertRaisesRegex(
        ValueError,
        'Tuple subset of `PartitionSpec` cannot contain `Manual` mixed with'
        ' `Auto` or `Explicit`'):
      aval.update(sharding=NamedSharding(mesh, P(('a', 'd'), 'b', 'c')))

  def test_aval_str_short(self):
    mesh = AbstractMesh(
        (2, 2, 2), ('a', 'b', 'c'),
        axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Manual))

    s = NamedSharding(mesh, P(unreduced={'a'}, reduced={'b'}))
    aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s,
                                vma=frozenset('c'))
    self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c, U:a, R:b}')

    s = NamedSharding(mesh, P(unreduced={'a'}))
    aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s,
                                vma=frozenset('c'))
    self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c, U:a}')

    s = NamedSharding(mesh, P(unreduced={'a'}))
    aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s)
    self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{U:a}')

    s = NamedSharding(mesh, P())
    aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s,
                                vma=frozenset('c'))
    self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c}')

    aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32)
    self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]')

  def test_modify_spec_auto_unreduced(self):
    mesh = AbstractMesh(
        (2, 2, 2), ('a', 'b', 'c'),
        axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto))
    spec = P(unreduced={'a', 'b', 'c'})
    out = core.modify_spec_for_auto_manual(spec, mesh)
    self.assertEqual(out, P(unreduced={'a', 'b'}))

    spec = P(reduced={'a', 'b', 'c'})
    out = core.modify_spec_for_auto_manual(spec, mesh)
    self.assertEqual(out, P(reduced={'a', 'b'}))

    spec = P(unreduced={'a', 'b'}, reduced={'c'})
    out = core.modify_spec_for_auto_manual(spec, mesh)
    self.assertEqual(out, P(unreduced={'a', 'b'}))

    spec = P(unreduced={'a', 'c'}, reduced={'b'})
    out = core.modify_spec_for_auto_manual(spec, mesh)
    self.assertEqual(out, P(unreduced={'a'}, reduced={'b'}))

    spec = P(unreduced={'c'}, reduced={'a', 'b'})
    out = core.modify_spec_for_auto_manual(spec, mesh)
    self.assertEqual(out, P(reduced={'a', 'b'}))

  def test_pspec_unreduced(self):
    pspec = P('a', 'b', None, unreduced={'c'}, reduced={'d'})
    self.assertEqual(
        repr(pspec),
        "PartitionSpec('a', 'b', None, unreduced={'c'}, reduced={'d'})")

    pspec1 = P('a', 'b', None, unreduced={'c'})
    self.assertEqual(repr(pspec1),
                     "PartitionSpec('a', 'b', None, unreduced={'c'})")

    pspec2 = P('a', 'b', None, unreduced={'c'})
    self.assertEqual(pspec1, pspec2)

    pspec3 = P('a', 'b', None, unreduced={'d'})
    self.assertNotEqual(pspec1, pspec3)

    out = P('x', unreduced={'z'}) + P('a', unreduced={'b'})
    self.assertEqual(out, P('x', 'a', unreduced={'z', 'b'}))

    pspec4 = P('x', unreduced={'y'})
    self.assertEqual(repr(pspec4),
                     "PartitionSpec('x', unreduced={'y'})")

    pspec5 = P(None, None, unreduced={'x'})
    self.assertEqual(repr(pspec5),
                     "PartitionSpec(None, None, unreduced={'x'})")

    pspec6 = P(None, unreduced={'x'})
    self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced={'x'})")

    pspec7 = P(unreduced={'x'})
    self.assertEqual(repr(pspec7), "PartitionSpec(unreduced={'x'})")

    with self.assertRaisesRegex(
        TypeError, 'unreduced in `__add__` of PartitionSpec'):
      P('x', unreduced={'z'}) + (None,) * 2

    with self.assertRaisesRegex(
        TypeError, "unreduced in `__radd__` of PartitionSpec"):
      (None,) * 2 + P('x', unreduced={'y'})

    with self.assertRaisesRegex(
        ValueError, "partitions cannot overlap with unreduced"):
      P('x', 'y', unreduced={'x'})

    with self.assertRaisesRegex(
        ValueError, "partitions cannot overlap with unreduced"):
      P('x', None, 'y', unreduced={'z', 'y'})

  def test_named_sharding_unreduced_error(self):
    mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))

    with self.assertRaisesRegex(
        ValueError, "Unreduced axes.*not found in mesh.*"):
      NamedSharding(mesh, P('x', unreduced={'a'}))

    with self.assertRaisesRegex(
        ValueError, "Unreduced axes can only refer to mesh axes.*Explicit"):
      NamedSharding(mesh, P('x', unreduced={'y', 'z'}))

    with self.assertRaisesRegex(
        ValueError, "unreduced cannot contain None.*"):
      NamedSharding(mesh, P('x', unreduced={'y', None}))

  def test_hlo_sharding_get_axis_sizes(self):
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [6, 35]
    op.iota_reshape_dims = [7, 10, 3]
    op.iota_transpose_perm = [2, 1, 0]
    s = GSPMDSharding(jax.devices(), op)
    self.assertIn('{devices=[6,35]<=[7,10,3]T(2,1,0)}', repr(s))
    self.assertEqual(s._to_xla_hlo_sharding(2).get_axis_sizes(), [7, 2, 5, 3])

  @parameterized.named_parameters(
      ('2d_mesh_x_y', (4, 2), P('x', 'y')),
      ('2d_mesh_x', (4, 2), P('x')),
      ('2d_mesh_y', (4, 2), P('y')),
      ('2d_mesh_none_y', (4, 2), P(None, 'y')),
      ('2d_mesh_none_x', (4, 2), P(None, 'x')),
      ('2d_mesh_xy', (4, 2), P(('x', 'y'))),
      ('2d_mesh_none_xy', (4, 2), P(None, ('x', 'y'))),
      ('2d_mesh_fully_replicated', (4, 2), P()),
      ('2d_mesh_x_none', (2, 1), P(('x',), None)),
      ('3d_mesh_none_none_z', (2, 2, 2), P(None, None, 'z')),
      ('3d_mesh_none_y_none', (2, 2, 2), P(None, 'y', None)),
      ('3d_mesh_x_y_none', (2, 2, 2), P('x', 'y', None)),
      ('3d_mesh_none_yz', (2, 2, 2), P(None, ('y', 'z'))),
      ('3d_mesh_x_none_yz', (2, 2, 2), P('x', None, ('y', 'z'))),
      ('3d_mesh_none_x_yz', (2, 2, 2), P(None, 'x', ('y', 'z'))),
      ('3d_mesh_xy_z', (2, 2, 2), P(('x', 'y'), 'z')),
      ('3d_mesh_xy_none_z', (2, 2, 2), P(('x', 'y'), None, 'z')),
      ('3d_mesh_x_y_z', (2, 2, 2), P('x', 'y', 'z')),
      ('3d_mesh_xz_y', (2, 2, 2), P(('x', 'z'), 'y')),
      ('3d_mesh_xz_none_y', (2, 2, 2), P(('x', 'z'), None, 'y')),
      ('3d_mesh_y_none_xz', (2, 2, 2), P('y', None, ('x', 'z'))),
      ('3d_mesh_none_y_xz', (2, 2, 2), P(None, 'y', ('x', 'z'))),
      ('3d_mesh2_none_none_z', (1, 2, 4), P(None, None, 'z')),
      ('3d_mesh2_x_none_none', (1, 2, 4), P('x', None, None)),
      ('3d_mesh_x_none_none', (2, 1, 1), P('x', None, None)),
  )
  def test_gspmd_sharding_shardy_lowering(self, mesh_shape, pspec):
    ndim = len(mesh_shape)
    mesh = jtu.create_mesh(
        mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')
    )
    ns = jax.sharding.NamedSharding(mesh, pspec)
    gs = GSPMDSharding(ns._device_assignment, ns._to_xla_hlo_sharding(ndim))
    out_sdy_sharding = gs._to_sdy_sharding(ndim)
    self.assertTrue(out_sdy_sharding, ns._to_sdy_sharding(ndim))

  def test_nested_tuple_pspec_error(self):
    with self.assertRaisesRegex(
        ValueError,
        "A tuple inside PartitionSpec cannot contain a nested tuple"):
      jax.P('x', 'y', ('z', ('a',)))

    with self.assertRaisesRegex(
        ValueError,
        "A tuple inside PartitionSpec cannot contain a nested tuple"):
      jax.P((('a', 'b'), 'c'))

  def test_make_mesh_accelerate_explicit(self):
    if deprecations.is_accelerated('jax-make-mesh-default-explicit'):
      mesh = jax.make_mesh((1,), 'x')
      self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,))
    else:
      with self.assertWarnsRegex(DeprecationWarning, "The default axis_types"):
        mesh = jax.make_mesh((1,), 'x')
        self.assertTupleEqual(mesh.axis_types, (AxisType.Auto,))


class RngShardingTest(jtu.JaxTestCase):
  # tests that the PRNGs are automatically sharded as expected

  @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_1d(self, num_devices):
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True)
    s = jax.sharding.NamedSharding(mesh, P('x'))

    n = num_devices ** 2
    global_x = jnp.arange(n).astype('uint32')
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    # check computation is fully partitioned and without any communication
    with jax.threefry_partitionable(True):
      unopt_txt = f.lower(x).as_text(dialect='hlo')
      opt_txt = f.lower(x).compile().as_text()
      self.assertIn(   f'[{n}]', unopt_txt)
      self.assertNotIn(f'[{n}]', opt_txt)
      self.assertNotIn('all-reduce', opt_txt)
      self.assertNotIn('collective-permute', opt_txt)

      # check against single-device reference
      y = f(x)
      y_ref1 = f(jax.device_put(x, jax.devices()[0]))
      self.assertArraysEqual(y, y_ref1)

  @parameterized.named_parameters(
      {"testcase_name": f"_{mesh_shape}_{pspec}",
       "mesh_shape": mesh_shape, "pspec": pspec}
      for mesh_shape in [(3, 2), (4, 2), (2, 3)]
      for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    global_shape = tuple(np.square(mesh_shape))

    mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
    s = jax.sharding.NamedSharding(mesh, pspec)

    n = math.prod(global_shape)
    global_x = np.arange(n).astype('uint32').reshape(global_shape)
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    # check computation is fully partitioned and without any communication
    with jax.threefry_partitionable(True):
      unopt_txt = f.lower(x).as_text(dialect='hlo')
      opt_txt = f.lower(x).compile().as_text()
      global_shape_fmt = ','.join(str(x) for x in global_shape)
      self.assertIn(   f'[{global_shape_fmt}]', unopt_txt)
      self.assertNotIn(f'[{global_shape_fmt}]', opt_txt)
      self.assertNotIn('all-reduce', opt_txt)
      self.assertNotIn('collective-permute', opt_txt)

      # check against single-device reference
      y = f(x)
      y_ref1 = f(jax.device_put(x, jax.devices()[0]))
      self.assertArraysEqual(y, y_ref1)

  def test_empty_mesh_creation(self):
    mesh = jax.sharding.Mesh(devices=np.empty((), dtype=object), axis_names=[])
    self.assertTrue(mesh.empty)
    self.assertEqual(mesh.size, 0)

    abstract_mesh = mesh.abstract_mesh
    self.assertTrue(abstract_mesh.empty)
    self.assertEqual(abstract_mesh.size, 0)

    abstract_mesh2 = jax.sharding.AbstractMesh((), ())
    self.assertTrue(abstract_mesh2.empty)
    self.assertEqual(abstract_mesh2.size, 0)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
